# Kernel: hgemm_tn_128x16

# ******************************************************************************
# Copyright 2014-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************


# Symbolic names for the shared-memory scratch offset and the constant-bank
# slots this kernel reads.
#   addr_zero : scratch slot past the two A/B tile buffers; the setup code
#               stores zeros there and loads them back to clear the accumulators.
#   gridDimA/B: launch grid dimensions.
#   param_*   : kernel arguments, in constant bank 0 starting at 0x140.
# NOTE(review): the kernel body references param_Rand[0] / param_Rand[1]
# (seed load in the setup block, seed store before EXIT) but no such entry
# exists in this mapping - confirm where those names are meant to be defined.
<CONSTANT_MAPPING>
    addr_zero  : 4x<128*8*2 + 16*8*2 + 0>

    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    param_C[0]      : c[0x0][0x140]
    param_C[1]      : c[0x0][0x144]
    param_A[0]      : c[0x0][0x148]
    param_A[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_alpha     : c[0x0][0x158]
    param_beta      : c[0x0][0x15c]
    param_flags     : c[0x0][0x160]
    param_lda8      : c[0x0][0x164]
    param_ldb8      : c[0x0][0x168]
    param_ldc       : c[0x0][0x16c]
    param_m         : c[0x0][0x170]
    param_n         : c[0x0][0x174]
    param_k         : c[0x0][0x178]
    param_ldaz      : c[0x0][0x17c]
    param_ldbz      : c[0x0][0x180]
    param_ldcz      : c[0x0][0x184]
    param_loops     : c[0x0][0x188]
</CONSTANT_MAPPING>

# Register-name aliases used throughout the kernel.
# Entries written with ':' list explicit register numbers; entries written
# with '~' give a register range for the listed temporaries (presumably
# assigned within that range by the assembler - confirm against the maxas docs).
#
# Several ranges overlap, so a name is only meaningful in the phase that uses it:
#   0-15    accumulators; czero00-15 and cx<x>y<y> name the same registers.
#   16-47   Rand0/1 and the setup temporaries, then the j<n>Ay / j<n>Bx
#           operands of the main loop, then c/b/d/Cy and epilogue temporaries.
#   48-83   global-load staging registers (load<n>A, load<n>B).
#   84-99   64-bit global pointers for the four in-flight A and B tiles.
#   100-112 state kept across the main loop (writeAs..ldb32, readAs..seed);
#           100-104 fall inside the epilogue temporary range 28-104.
<REGISTER_MAPPING>

    16-17 : Rand<0-1>

    18-47 ~ lda, ldb, ldaz, ldbz, lda8, ldb8, ta, tb, tid1, tid96, tidAX, tidBX, tidY, txa, txb, dimA, flag

    0-15  : czero<00-15>

    3, 2,11,10 : cx<0-3>y0
    7, 6,15,14 : cx<0-3>y1
    1, 0, 9, 8 : cx<0-3>y2
    5, 4,13,12 : cx<0-3>y3

    16-23   : j0Ay<0-3>, j0Bx<0-3>
    24-31   : j1Ay<0-3>, j1Bx<0-3>
    32-39   : j2Ay<0-3>, j2Bx<0-3>
    40-47   : j3Ay<0-3>, j3Bx<0-3>

    48-55   : load0A<0-7>
    56-63   : load1A<0-7>
    64-71   : load2A<0-7>
    72-79   : load3A<0-7>

    80-83   : load<0-3>B

    84-87   : track0A<0-1>, track0B<0-1>
    88-91   : track1A<0-1>, track1B<0-1>
    92-95   : track2A<0-1>, track2B<0-1>
    96-99   : track3A<0-1>, track3B<0-1>

    100-104 ~ writeAs, writeBs, k, lda32, ldb32
    105-112 ~ readAs, readBs, tid, blkA, blkB, blkZ, tbid, seed

    16-25   : c<0-3>, b<0-1>, d3, d2, d1, d0
    26-27   : Cy<0-1>
    28-104  ~ ldc, ldcz, ldc1, writeCs, readCs, tidCX, tidCY, cx, cy, ci, xmad_c, alpha, beta, flags, tid31, lfsr<0-2>, exp<0-3>, rand<0-3>, lfsr<0-2>_1, lfsr<0-2>_2, clk_shf1, clk_shf2

</REGISTER_MAPPING>

// Kernel setup: read thread/block ids, clear the accumulators, compute the
// global A/B pointers and the shared-memory read/write offsets, and issue the
// first four tile loads.
//
// Each S2R below signals its own write barrier (1-4); the first consumer of
// each value carries the matching wait mask (01, 02, 04, 08).
--:-:1:-:1      S2R tid,  SR_TID.X;
--:-:2:-:1      S2R blkA, SR_CTAID.Y;
--:-:3:-:1      S2R blkB, SR_CTAID.Z;
--:-:4:-:1      S2R blkZ, SR_CTAID.X;

<SCHEDULE_BLOCK>
--:-:-:-:1      MOV k, param_k;
// Store zeros to the scratch slot and load them back into all 16 accumulators
// (czero00-15 are the same registers as cx<x>y<y>).
--:-:-:-:1      STS.128 [addr_zero], RZ;

--:-:-:-:1      LDS.U.128 czero00, [addr_zero];
--:-:-:-:1      LDS.U.128 czero04, [addr_zero];
--:-:-:-:1      LDS.U.128 czero08, [addr_zero];
--:-:-:-:1      LDS.U.128 czero12, [addr_zero];

// Grab a seed for this thread
// tbid = blkB*gridDimA*256 + blkA*256 + tid
// seed index = tbid & (2048*32 - 1)
// P4 = bit 0 of param_flags (random rounding); the seed is only loaded when set.
// NOTE(review): param_Rand[0]/[1] are not defined in this file's constant mapping.
--:-:-:-:1      MOV flag, param_flags;
--:-:-:-:1      LOP.AND.NZ P4, RZ, flag, 0x1;
--:-:-:-:1      MOV dimA, gridDimA;
03:-:-:-:1      ISCADD tbid, blkA, tid, 8;
04:-:-:-:1      XMAD.U16.U16 dimA, blkB, dimA, RZ;
--:-:-:-:1      ISCADD tbid, dimA, tbid, 8;
--:-:-:-:1      LOP.AND seed, tbid, 1x<2048*32 - 1>;
--:-:-:-:1      LEA      Rand0.CC, seed, param_Rand[0],     0x2;
--:-:-:-:1      LEA.HI.X Rand1,    seed, param_Rand[1], RZ, 0x2;
--:-:-:-:1  @P4 LDG.E.CS seed, [Rand];

// tidBX =  tid & 15
// tidAX = (tid & 15) << 3
// tidY = (tid >> 4) & 7
01:-:-:-:1      LOP.AND tidBX, tid,   15;
--:-:-:-:1      SHL     tidAX, tidBX, 3;
--:-:-:-:1      BFE.U32 tidY,  tid,   0x304; // 3 bits at position 4

// lda/ldb  = lda8/ldb8 >> 4 (strides used in the element-index math below)
// lda32/ldb32 = lda8/ldb8 << 2 (pointer advance applied once per staged tile)
--:-:-:-:1      MOV lda8,   param_lda8;
--:-:-:-:1      MOV ldb8,   param_ldb8;
--:-:-:-:1      SHR.U32 lda, lda8, 4;
--:-:-:-:1      SHR.U32 ldb, ldb8, 4;
--:-:-:-:1      SHL lda32, lda8, 2;
--:-:-:-:1      SHL ldb32, ldb8, 2;
--:-:-:-:1      MOV ldaz, param_ldaz;
--:-:-:-:1      MOV ldbz, param_ldbz;


// trackA += (blkA*128 + lda*tidY + tidAX) * 2
// (plus ldaz*blkZ for the batch index blkZ)
02:-:-:-:1      ISCADD   txa, blkA, tidAX,  7;
--:-:-:-:1      XMAD.LO2 ta,  lda,  tidY, txa;
08:-:-:-:1      XMAD.LO2 ta,  ldaz, blkZ, ta;
--:-:-:-:1      LEA      track0A0.CC, ta, param_A[0],     0x1;
--:-:-:-:1      LEA.HI.X track0A1,    ta, param_A[1], RZ, 0x1;

// P5 = this thread's A column (txa) is inside m
--:-:-:-:1      ISETP.LT.AND P5, PT, txa, param_m, PT;

// trackB += (blkB*16 + ldb*tidY + tidBX) * 2
// (plus ldbz*blkZ for the batch index blkZ)
04:-:-:-:1      ISCADD   txb, blkB, tidBX, 4;
--:-:-:-:1      XMAD.LO2 tb,  ldb,  tidY, txb;
08:-:-:-:1      XMAD.LO2 tb,  ldbz, blkZ, tb;
--:-:-:-:1      LEA      track0B0.CC, tb, param_B[0],     0x1;
--:-:-:-:1      LEA.HI.X track0B1,    tb, param_B[1], RZ, 0x1;

// P6 = this thread's B column (txb) is inside n
--:-:-:-:1      ISETP.LT.AND P6, PT, txb, param_n, PT;

// writeAs = (128*tidY + tidAX) * 4
--:-:-:-:1      ISCADD writeAs, tidY, tidAX, 7;
--:-:-:-:1      SHL    writeAs, writeAs, 2;

// writeBs = (16*tidY + tidBX) * 4, offset past the 128*8-word A region
--:-:-:-:1      ISCADD writeBs, tidY, tidBX, 4;
--:-:-:-:1      ISCADD writeBs, writeBs, 4x<128*8>, 2;

// Start the read buffers low
// readAs = (((tid >> 1) & 7) | ((tid & 96) >> 2)) << 4
--:-:-:-:1      LOP.AND tid96,  tid,    96;
--:-:-:-:1      SHR.U32 tid96,  tid96,  2;
--:-:-:-:1      BFE.U32 readAs, tid,    0x301; // 3 bits at position 1
--:-:-:-:1      LOP.OR  readAs, readAs, tid96;
--:-:-:-:1      SHL     readAs, readAs, 4;

// readBs  = (((tid & 0x10) >> 3) | (tid & 1)) << 4;
// (also offset past the 128*8-word A region)
--:-:-:-:1      LOP.AND tid1,   tid,    1;
--:-:-:-:1      LOP.AND readBs, tid,    0x10;
--:-:-:-:1      SHR.U32 readBs, readBs, 3;
--:-:-:-:1      LOP.OR  readBs, readBs, tid1;
--:-:-:-:1      ISCADD  readBs, readBs, 4x<128*8>, 4;

// track1..3 = track0 advanced by one, two and three lda8 / ldb8 strides,
// so four consecutive tiles can be in flight at once.
--:-:-:-:1      IADD   track1A0.CC, track0A0, lda8;
--:-:-:-:1      IADD.X track1A1,    track0A1, RZ;
--:-:-:-:1      IADD   track1B0.CC, track0B0, ldb8;
--:-:-:-:1      IADD.X track1B1,    track0B1, RZ;

--:-:-:-:1      IADD   track2A0.CC, track1A0, lda8;
--:-:-:-:1      IADD.X track2A1,    track1A1, RZ;
--:-:-:-:1      IADD   track2B0.CC, track1B0, ldb8;
--:-:-:-:1      IADD.X track2B1,    track1B1, RZ;

--:-:-:-:1      IADD   track3A0.CC, track2A0, lda8;
--:-:-:-:1      IADD.X track3A1,    track2A1, RZ;
--:-:-:-:1      IADD   track3B0.CC, track2B0, ldb8;
--:-:-:-:1      IADD.X track3B1,    track2B1, RZ;

// Issue the loads for all four tiles, guarded by the bounds predicates.
// Tile n signals write barrier n+3 (3..6), which the matching convert/store
// stage waits on later.
<ORDERED>
--:-:3:-:1  @P5 LDG.E.CI.128 load0A, [track0A];
--:-:3:-:1  @P6 LDG.E.CI.S16 load0B, [track0B];

--:-:4:-:1  @P5 LDG.E.CI.128 load1A, [track1A];
--:-:4:-:1  @P6 LDG.E.CI.S16 load1B, [track1B];

--:-:5:-:1  @P5 LDG.E.CI.128 load2A, [track2A];
--:-:5:-:1  @P6 LDG.E.CI.S16 load2B, [track2B];

--:-:6:-:1  @P5 LDG.E.CI.128 load3A, [track3A];
--:-:6:-:1  @P6 LDG.E.CI.S16 load3B, [track3B];
</ORDERED>

// Loop predicates for the first pass, then consume 32 from k:
//   P0 = k >= 32        (convert/store stages and the loop branch are enabled)
//   P3 = k > 32 && P5   (further A loads are enabled)
//   P4 = k > 32 && P6   (further B loads are enabled; P4 is reused here after
//                        serving as the random-round flag above)
--:-:-:-:1      ISETP.GE.AND P0, PT, k, 32, PT;
--:-:-:-:1      ISETP.GT.AND P3, PT, k, 32, P5;
--:-:-:-:1      ISETP.GT.AND P4, PT, k, 32, P6;
--:-:-:-:1      IADD k, k, -32;
</SCHEDULE_BLOCK>

// Prologue for tile 0: wait for its global load (barrier 3, mask 04), widen
// the fp16 data to fp32 in place, store it to shared buffer 0, then prefetch
// the first two shared lines and re-issue the tile-0 global load.
//
// The eight A conversions run from the highest destination down so that each
// packed source register (load0A3..load0A0) is fully read before it is
// overwritten. The track0 pointer advance (lda32 / ldb32) is interleaved with
// stall count 0.
04:-:-:-:4      F2F.F32.F16 load0A7, load0A3.H1;
--:-:-:-:4      F2F.F32.F16 load0A6, load0A3.H0;
--:-:-:-:0      IADD   track0A0.CC, track0A0, lda32;
--:-:-:-:4      F2F.F32.F16 load0A5, load0A2.H1;
--:-:1:-:4      F2F.F32.F16 load0A4, load0A2.H0;
--:-:-:-:0      IADD.X track0A1, track0A1, RZ;
--:-:-:-:4      F2F.F32.F16 load0A3, load0A1.H1;
--:-:-:-:4      F2F.F32.F16 load0A2, load0A1.H0;
--:-:-:-:0      IADD   track0B0.CC, track0B0, ldb32;
--:-:-:-:4      F2F.F32.F16 load0A1, load0A0.H1;
--:-:2:-:4      F2F.F32.F16 load0A0, load0A0.H0;
--:-:-:-:0      IADD.X track0B1, track0B1, RZ;
--:-:3:-:1      F2F.F32.F16 load0B, load0B;

// Each store waits on the barrier signalled by the conversion that produced
// its last input: 01 -> load0A4..7, 02 -> load0A0..3, 04 -> load0B.
01:-:-:-:1      STS.128 [writeAs + 4x<0*(128*8 + 16*8) + 4>], load0A4;
02:-:-:-:1      STS.128 [writeAs + 4x<0*(128*8 + 16*8) + 0>], load0A0;
04:-:-:-:1      STS     [writeBs + 4x<0*(128*8 + 16*8) + 0>], load0B;

--:-:-:-:5      BAR.SYNC 0;

// Prefetch shared lines 0 and 1 of buffer 0 into the j0 / j1 operand sets
// (barriers 1 and 2), then start the next tile-0 global load when P3 / P4
// say more than 32 of k remained.
--:-:1:-:1      LDS.U.128 j0Ay0, [readAs + 4x<0*128 + 0*(128*8 + 16*8)>];
--:-:1:-:1      LDS.U.128 j0Bx0, [readBs + 4x<0*16  + 0*(128*8 + 16*8)>];
--:-:2:-:1      LDS.U.128 j1Ay0, [readAs + 4x<1*128 + 0*(128*8 + 16*8)>];
--:-:2:-:1      LDS.U.128 j1Bx0, [readBs + 4x<1*16  + 0*(128*8 + 16*8)>];
--:-:3:-:1  @P3 LDG.E.CI.128 load0A, [track0A];
--:-:3:-:1  @P4 LDG.E.CI.S16 load0B, [track0B];

LOOP:

<CODE>

    # Emits the text of the unrolled main loop.
    #
    # One pass covers four k-steps ($k = 0..3). Each k-step has eight shared
    # lines ($j = 0..7) and each line has sixteen FFMAs ($c = 0..15), one per
    # accumulator, visited in the order held in @cOrder. Extra instructions
    # (fp16->fp32 conversion, pointer advance, shared stores, barrier, global
    # loads, loop bookkeeping) are looked up in %insert under "j<j>c<c>" and
    # emitted right after the FFMA at that slot.
    our @top;
    our %insert;

    # Accumulator visiting order: a four-step pattern applied to each
    # (x-base, y-base) pair, with the y-base order flipped for the second x-base.
    my @swirl  = ([0,2],[1,2],[1,0],[0,0]);
    my @yBases = (0,1);
    my @cOrder;
    for my $xBase (0,2)
    {
        for my $yBase (@yBases)
        {
            for my $step (@swirl)
            {
                push @cOrder, [$xBase + $step->[0], $yBase + $step->[1]];
            }
        }
        @yBases = reverse @yBases;
    }

    my $out = join '', @top;

    for my $k (0 .. 3)
    {
        # Tile staged during this k-step, the shared buffer it is stored into,
        # the write barrier its loads signal, and the matching wait mask.
        my $store    = ($k + 1) & 3;
        my $shareBuf = ($k + 1) & 1;
        my $loadBar  = $store + 3;
        my $storBar  = sprintf '%02x', 1 << ($store + 2);

        %insert =
        (
            # Widen the staged tile to fp32 in place (highest destination first).
            j0c11 => "$storBar:-:-:-:1  \@P0 F2F.F32.F16 load${store}A7, load${store}A3.H1;\n",
            j0c15 => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A6, load${store}A3.H0;\n",
            j1c3  => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A5, load${store}A2.H1;\n",
            j1c7  => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A4, load${store}A2.H0;\n",
            j1c11 => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A3, load${store}A1.H1;\n",
            j1c15 => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A2, load${store}A1.H0;\n",
            j2c3  => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A1, load${store}A0.H1;\n",
            j2c7  => "--:-:-:-:1  \@P0 F2F.F32.F16 load${store}A0, load${store}A0.H0;\n",
            j2c11 => "--:-:$loadBar:-:1  \@P0 F2F.F32.F16 load${store}B, load${store}B;\n",

            # Advance the staged tile's global pointers.
            j2c12 => "--:-:-:-:1  \@P0 IADD   track${store}A0.CC, track${store}A0, lda32;\n",
            j3c1  => "--:-:-:-:1  \@P0 IADD.X track${store}A1,    track${store}A1, RZ;\n",
            j3c3  => "--:-:-:-:1  \@P0 IADD   track${store}B0.CC, track${store}B0, ldb32;\n",
            j3c8  => "--:-:-:-:1  \@P0 IADD.X track${store}B1,    track${store}B1, RZ;\n",

            # Store the widened tile to the other shared buffer.
            j3c9  => "$storBar:-:-:-:1  \@P0 STS.128 [writeAs + 4x<$shareBuf*(128*8 + 16*8) + 0>], load${store}A0;\n",
            j4c4  => "--:-:-:-:1  \@P0 STS.128 [writeAs + 4x<$shareBuf*(128*8 + 16*8) + 4>], load${store}A4;\n",
            j4c6  => "--:-:-:-:1  \@P0 STS     [writeBs + 4x<$shareBuf*(128*8 + 16*8) + 0>], load${store}B;\n",

            j5c15 => "--:-:-:-:5  \@P0 BAR.SYNC 0;\n",

            # Refill the staging registers from global memory.
            j6c1  => "--:-:$loadBar:-:1  \@P3 LDG.E.CI.128 load${store}A, [track${store}A];\n",
            j6c3  => "--:-:$loadBar:-:1  \@P4 LDG.E.CI.S16 load${store}B, [track${store}B];\n",
        );

        # Only the last k-step refreshes the loop predicates, consumes k and
        # branches back to the top.
        if ($k == 3)
        {
            $insert{j0c4}  = "--:-:-:-:1      ISETP.GE.AND P0, PT, k, 32, PT;\n";
            $insert{j0c6}  = "--:-:-:-:1      ISETP.GT.AND P3, PT, k, 32, P5;\n";
            $insert{j0c8}  = "--:-:-:-:1      ISETP.GT.AND P4, PT, k, 32, P6;\n";
            $insert{j0c10} = "--:-:-:-:1      IADD k, k, -32;\n";

            $insert{j7c15} = "--:-:-:Y:5  \@P0 BRA.U LOOP;\n";
        }

        for my $j (0 .. 7)
        {
            # Lines 6 and 7 prefetch from the buffer being filled; on the last
            # k-step those prefetches are predicated on the loop continuing.
            my $tailLine = $j >= 6;
            my $ldsPred  = ($tailLine && $k == 3) ? '@P0' : '   ';
            my $ldsBuf   = $tailLine ? (($k + 1) & 1) : ($k & 1);
            my $ldsBar   = ($j & 1) ? 2 : 1;
            my $ldsRegs  = ($j + 2) & 3;
            my $ldsLine  = ($j + 2) & 7;
            my $operands = $j & 3;

            $insert{"j${j}c0"} = sprintf "--:-:%d:-:1  %s LDS.U.128 j%dAy0, [readAs + 4x<%d*128 + %d*(128*8 + 16*8)>];\n", $ldsBar, $ldsPred, $ldsRegs, $ldsLine, $ldsBuf;
            $insert{"j${j}c2"} = sprintf "--:-:%d:-:1  %s LDS.U.128 j%dBx0, [readBs + 4x<%d*16  + %d*(128*8 + 16*8)>];\n", $ldsBar, $ldsPred, $ldsRegs, $ldsLine, $ldsBuf;

            for my $c (0 .. 15)
            {
                my ($accX, $accY) = @{ $cOrder[$c] };

                my $extra = $insert{"j${j}c$c"} || '';

                # Only the first line of an insert decides the FFMA stall:
                # 0 when it is one of the listed instruction kinds, else 1.
                my ($firstLine) = split "\n", $extra;
                my $pairs = defined $firstLine && $firstLine =~ /LDS|F2F|I2I|LDG|STS|BAR|BRA/;
                my $stall = $pairs ? 0 : 1;

                # The first FFMA of a line waits for that line's LDS barrier.
                my $wait  = $c == 0 ? "0$ldsBar" : '--';
                my $yield = ($c == 8 && $stall) ? 'Y' : '-';

                $out .= sprintf "%s      FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n%s",
                    "$wait:-:-:$yield:$stall",
                    $accX, $accY,
                    $operands, $accX,
                    $operands, $accY,
                    $accX, $accY,
                    $extra;
            }
        }
        $out .= "\n";
    }
    return $out;

</CODE>

// Epilogue setup: compute the output pointer, the shared-memory offsets used
// to regroup results, the store predicates and the random-number state.
// NOTE(review): the shared epilogue file named on the next line is commented
// out and its code appears inline below; confirm the preprocessor ignores
// include tags that sit inside a comment.
//<INCLUDE file="hgemm_common_128x16.sass"/>

<SCHEDULE_BLOCK>

--:-:-:-:1      MOV alpha, param_alpha;
--:-:-:-:1      MOV beta,  param_beta;
--:-:-:-:1      MOV flags, param_flags;

// writeCs = (readAs / 4) * 16 + readBs;
// The 0x1ff masks first drop the buffer/region offsets from readAs / readBs.
--:-:-:-:1      LOP.AND readAs, readAs, 0x1ff;
--:-:-:-:1      LOP.AND readBs, readBs, 0x1ff;
--:-:-:-:1      ISCADD  writeCs, readAs, readBs, 2;

// tidCX = (tid & 3) << 2
// tidCY = tid >> 2
// tid31 = tid & 31 (mixed into the random-state shifts below)
--:-:-:-:1      LOP.AND tid31, tid,   31;
--:-:-:-:1      LOP.AND tidCX, tid,   3;
--:-:-:-:1      SHL     tidCX, tidCX, 2;
--:-:-:-:1      SHR.U32 tidCY, tid,   2;

// readCs = (tidCY*16 + tidCX)   << 2;
--:-:-:-:1      ISCADD readCs, tidCY, tidCX, 4;
--:-:-:-:1      SHL    readCs, readCs, 2;

// cx = blkB*16 + tidCX;
--:-:-:-:1      ISCADD cx, blkB, tidCX, 4;

// cy = blkA*128 + tidCY*4
--:-:-:-:1      SHL     cy, tidCY, 2;
--:-:-:-:1      ISCADD  cy, blkA,  cy, 7;

// C += (cy*ldc + cx) * 2;
// (plus ldcz*blkZ for the batch index blkZ)
--:-:-:-:1      MOV  ldc,  param_ldc;
--:-:-:-:1      MOV  ldcz, param_ldcz;
--:-:-:-:1      XMAD.LO  ci, cy, ldc, cx, xmad_c;
--:-:-:-:1      XMAD.LO2 ci, ldcz, blkZ,  ci;
--:-:-:-:1      LEA      Cy0.CC, ci, param_C[0],     1;
--:-:-:-:0      LEA.HI.X Cy1,    ci, param_C[1], RZ, 1;

// cx < n
--:-:-:-:1      ISETP.LT.AND P6, PT, cx, param_n, PT;

// beta != 0
// (combined with P6, so P5 implies cx < n as well)
--:-:-:-:1      ISETP.NE.AND P5, PT, beta, RZ, P6;

// Random Round flag
--:-:-:-:2      LOP.AND.NZ P4, RZ, flags, 1;

// Apply relu
--:-:-:-:1      LOP.AND.NZ P3, RZ, flags, 2;

// ldc1 = ldc * 2: pointer step applied after each stored row in STORE_C
--:-:-:-:1      SHL ldc1, ldc, 1;

// Seed the Tausworthe
// lfsr0 = seed ^ tbid; lfsr1 / lfsr2 start from the clock and global timer
// and are rotated by a per-lane amount ((counter & 31) ^ tid31).
// tbid is then masked to the index used for the seed write-back.
--:-:-:-:1      LOP.XOR lfsr0, seed, tbid;
--:-:-:-:1      CS2R lfsr1, SR_CLOCKLO;
--:-:-:-:1      CS2R lfsr2, SR_GLOBALTIMERLO;
--:-:-:-:1      LOP.AND clk_shf1, lfsr1, 31;
--:-:-:-:1      LOP.AND clk_shf2, lfsr2, 31;
--:-:-:-:1      LOP.XOR clk_shf1, clk_shf1, tid31;
--:-:-:-:1      LOP.XOR clk_shf2, clk_shf2, tid31;
--:-:-:-:1      SHF.R.U64 lfsr1, lfsr1, clk_shf1, lfsr1;
--:-:-:-:1      SHF.R.U64 lfsr2, lfsr2, clk_shf2, lfsr2;
--:-:-:-:1      LOP.AND tbid, tbid, 1x<2048*32 - 1>;

</SCHEDULE_BLOCK>

// All threads must be done reading the tile buffers before STORE_C reuses
// shared memory through writeCs / readCs.
--:-:-:-:5      BAR.SYNC 0;

<CODE>

    # Emits, for each of the four accumulator rows y = 0..3, the alpha scaling
    # of cx0..cx3 into c0..c3 followed by a call to STORE_C.
    my $scaleRow =
        "--:-:-:-:1      FMUL c0, cx0y%d, alpha;\n" .
        "--:-:-:-:1      FMUL c1, cx1y%d, alpha;\n" .
        "--:-:-:-:1      FMUL c2, cx2y%d, alpha;\n" .
        "--:-:-:-:0      FMUL c3, cx3y%d, alpha;\n";
    my $callStore = "--:-:-:-:5      CAL STORE_C;\n\n";

    my @rows = map { sprintf($scaleRow, ($_) x 4) . $callStore } 0 .. 3;

    return join '', @rows;

</CODE>

// After the four rows are stored: fold the three generator words into one
// value (LOP3.LUT 0x96 - the same three-input op RANDOM_ROUND documents as
// lfsr0 ^ lfsr1 ^ lfsr2) and, when random rounding is on (P4), write it back
// to this thread's slot Rand[tbid] as the next seed.
// Rand0/1 share registers with c0/c1, which are no longer needed here.
// NOTE(review): param_Rand[0]/[1] are not defined in this file's constant mapping.
--:-:-:-:6      LEA      Rand0.CC, tbid, param_Rand[0],     0x2;
--:-:-:-:1      LEA.HI.X Rand1,    tbid, param_Rand[1], RZ, 0x2;
--:-:-:-:2      LOP3.LUT seed, lfsr0, lfsr1, lfsr2, 0x96;
--:-:-:-:1  @P4 STG.E.CS [Rand], seed;

--:-:-:-:5      EXIT;


// STORE_C: write one output row of four fp16 values.
//   In:     c0-c3  alpha-scaled accumulators for the current row
//           P3 relu flag, P4 random-round flag,
//           P5 (beta != 0 and cx < n), P6 (cx < n)
//           cy current row index, Cy pointer to the row in C
//   Out:    four fp16 values stored at [Cy] when P0 holds;
//           cy advanced by 1 and Cy by ldc1
//   Writes: P0, P1, b0-b1, d0-d3, c0-c3 (plus whatever RANDOM_ROUND writes)
STORE_C:

// P1 = row in range and beta term needed; P0 = row in range and store allowed
--:-:-:-:2      ISETP.LT.AND P1, PT, cy, param_m, P5;
--:-:-:Y:b      ISETP.LT.AND P0, PT, cy, param_m, P6;
--:-:-:-:0      IADD cy, cy, 1;

// Fetch the existing four fp16 values of C for the beta term (barrier 1)
--:-:1:-:1  @P1 LDG.E.64 b0, [Cy];

// Apply relu
--:-:-:-:1  @P3 FMNMX c0, c0, RZ, !PT;
--:-:-:-:1  @P3 FMNMX c1, c1, RZ, !PT;
--:-:-:-:1  @P3 FMNMX c2, c2, RZ, !PT;
--:-:-:-:4  @P3 FMNMX c3, c3, RZ, !PT;

// Exchange the row through shared memory: written at writeCs, read back at
// readCs (barrier 5). Presumably this regroups values so each thread holds
// four adjacent outputs - confirm against the writeCs/readCs layout.
--:-:-:-:1      STS.128 [writeCs], c0;
--:-:5:-:1      LDS.U.128 c0, [readCs];

// Widen the loaded C values to fp32; each conversion signals its own barrier
01:-:1:-:4  @P1 F2F.F32.F16 d3, b1.H1;
--:-:2:-:4  @P1 F2F.F32.F16 d2, b1.H0;
--:-:3:-:4  @P1 F2F.F32.F16 d1, b0.H1;
--:-:4:-:1  @P1 F2F.F32.F16 d0, b0.H0;

// c += beta * d; the first FFMA also waits for the LDS above (mask 11)
11:-:-:-:1  @P1 FFMA c3, d3, beta, c3;
02:-:-:-:1  @P1 FFMA c2, d2, beta, c2;
04:-:-:-:1  @P1 FFMA c1, d1, beta, c1;
08:-:-:-:0  @P1 FFMA c0, d0, beta, c0;

// Narrow to fp16: RANDOM_ROUND when P4 is set, plain F2F conversion otherwise
--:-:-:-:5  @P4 BRA.U DO_RANDOM1;

--:-:1:-:4      F2F.F16.F32 c0, c0;
--:-:2:-:4      F2F.F16.F32 c1, c1;
--:-:3:-:4      F2F.F16.F32 c2, c2;
--:-:4:-:1      F2F.F16.F32 c3, c3;

--:-:-:-:5      BRA.U END_ROUND1;

DO_RANDOM1:

--:-:-:-:5      CAL RANDOM_ROUND;

END_ROUND1:

// Pack 2 16 bit values into 32 bit words
// (wait masks 03 / 0c cover the four conversion barriers on either path)
03:-:-:-:2      BFI c0, c1, 0x1010, c0;
0c:-:-:-:2      BFI c1, c3, 0x1010, c2;

// The store signals read barrier 1 so Cy is not advanced before it is consumed
--:1:-:-:2  @P0 STG.E.64 [Cy], c0;

01:-:-:-:6      IADD   Cy0.CC, Cy0, ldc1;
--:-:-:-:0      IADD.X Cy1,    Cy1, RZ;

--:-:-:-:5      RET;

// RANDOM_ROUND: narrow c0-c3 from fp32 to fp16 with randomized rounding.
//   In:     c0-c3 (fp32), lfsr0-2 generator state
//   Out:    c0-c3 hold fp16 results; the conversions signal barriers 1-4,
//           which the caller waits on; lfsr0-2 advanced by one step
//   Writes: exp0-3, rand0-3, lfsr<n>_1, lfsr<n>_2
// Method: add a random value scaled to sit below the fp16 lsb of each number
// (round toward zero), then truncate to fp16.
RANDOM_ROUND:

<SCHEDULE_BLOCK>

// Strip mantissa and leave sign+exponent
--:-:-:-:1      LOP32I.AND exp0, c0, 0xff800000;
--:-:-:-:1      LOP32I.AND exp1, c1, 0xff800000;
--:-:-:-:1      LOP32I.AND exp2, c2, 0xff800000;
--:-:-:-:1      LOP32I.AND exp3, c3, 0xff800000;

// Find the exponent that will shift 32 bits of integer data
// out past the lsb of this number as an fp16
// exp *= 2^-10 * 2^-32  (2^-42)
// (0x2a800000 is the fp32 encoding of 2^-42)
--:-:-:-:1      FMUL32I exp0, exp0, 0x2a800000;
--:-:-:-:1      FMUL32I exp1, exp1, 0x2a800000;
--:-:-:-:1      FMUL32I exp2, exp2, 0x2a800000;
--:-:-:-:1      FMUL32I exp3, exp3, 0x2a800000;

// lfsr0 = ((lfsr0 & 0xfffffffe) << 12) ^ (((lfsr0 << 13) ^ lfsr0) >> 19);
--:-:-:-:1      LOP32I.AND lfsr0_1, lfsr0, 0xfffffffe;
--:-:-:-:1      SHL lfsr0_1, lfsr0_1, 12;
--:-:-:-:1      SHL lfsr0_2, lfsr0, 13;
--:-:-:-:1      LOP.XOR lfsr0_2, lfsr0_2, lfsr0;
--:-:-:-:1      SHR.U32 lfsr0_2, lfsr0_2, 19;
--:-:-:-:1      LOP.XOR lfsr0, lfsr0_1, lfsr0_2;

// lfsr1 = ((lfsr1 & 0xfffffff8) <<  4) ^ (((lfsr1 << 2)  ^ lfsr1) >> 25);
--:-:-:-:1      LOP32I.AND lfsr1_1, lfsr1, 0xfffffff8;
--:-:-:-:1      SHL lfsr1_1, lfsr1_1, 4;
--:-:-:-:1      SHL lfsr1_2, lfsr1, 2;
--:-:-:-:1      LOP.XOR lfsr1_2, lfsr1_2, lfsr1;
--:-:-:-:1      SHR.U32 lfsr1_2, lfsr1_2, 25;
--:-:-:-:1      LOP.XOR lfsr1, lfsr1_1, lfsr1_2;

// lfsr2 = ((lfsr2 & 0xfffffff0) << 11) ^ (((lfsr2 << 3)  ^ lfsr2) >> 11);
--:-:-:-:1      LOP32I.AND lfsr2_1, lfsr2, 0xfffffff0;
--:-:-:-:1      SHL lfsr2_1, lfsr2_1, 11;
--:-:-:-:1      SHL lfsr2_2, lfsr2, 3;
--:-:-:-:1      LOP.XOR lfsr2_2, lfsr2_2, lfsr2;
--:-:-:-:1      SHR.U32 lfsr2_2, lfsr2_2, 11;
--:-:-:-:1      LOP.XOR lfsr2, lfsr2_1, lfsr2_2;

// rand = lfsr0 ^ lfsr1 ^ lfsr2;
// generate 3 other rotations of this rand
// (rotate right by 8, 16 and 24 bits, one value per output element)
--:-:-:-:1      LOP3.LUT  rand0, lfsr0, lfsr1, lfsr2, 0x96;
--:-:-:-:1      SHF.R.U64 rand1, rand0,  8, rand0;
--:-:-:-:1      SHF.R.U64 rand2, rand0, 16, rand0;
--:-:-:-:0      SHF.R.U64 rand3, rand0, 24, rand0;
// Disabled alternative: force a fixed value instead of the random one
//--:-:-:-:1      MOV32I rand0, 0x80000000;
//--:-:-:-:1      MOV32I rand1, 0x80000000;
//--:-:-:-:1      MOV32I rand2, 0x80000000;
//--:-:-:-:1      MOV32I rand3, 0x80000000;
</SCHEDULE_BLOCK>

// Convert rand to float
--:-:1:-:4      I2F.F32.U32.RZ rand0, rand0;
--:-:2:-:4      I2F.F32.U32.RZ rand1, rand1;
--:-:3:-:4      I2F.F32.U32.RZ rand2, rand2;
--:-:4:-:1      I2F.F32.U32.RZ rand3, rand3;

// Scale the random number so msb is one below lsb of fp16
// Add scaled random to number to round
// (each FFMA waits on the barrier of the conversion that feeds it)
01:-:-:-:1      FFMA.RZ c0, rand0, exp0, c0;
02:-:-:-:1      FFMA.RZ c1, rand1, exp1, c1;
04:-:-:-:1      FFMA.RZ c2, rand2, exp2, c2;
08:-:-:-:0      FFMA.RZ c3, rand3, exp3, c3;

// Truncate number to fp16
// (barriers 1-4 are left pending for the caller's packing step)
--:-:1:-:4      F2F.F16.F32.RZ c0, c0;
--:-:2:-:4      F2F.F16.F32.RZ c1, c1;
--:-:3:-:4      F2F.F16.F32.RZ c2, c2;
--:-:4:-:1      F2F.F16.F32.RZ c3, c3;

--:-:-:-:5      RET;
